from read_dataset import read_dataset
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import cvxpy as cp



def real_data(l, seed):
    """Load dataset ``l`` with ``read_dataset``, shuffle its rows and return tensors.

    Rows of X are permuted together with their labels using NumPy's global RNG,
    which is reseeded with ``seed`` first. Labels are mapped through
    ``(Y + 1) / 2`` and truncated to int (this maps -1/+1 labels to 0/1;
    NOTE(review): assumes read_dataset returns labels in {-1, +1} — confirm).

    Returns:
        (n, d, X, Y) where X is a float tensor of shape (n, d) and Y is a float
        tensor of shape (n,).
    """
    np.random.seed(seed)
    n, d, X, Y = read_dataset(l)
    XY = np.hstack([X, Y.reshape(-1, 1)])
    XY = np.random.permutation(XY)
    X = XY[:, :-1]
    Y = ((XY[:, -1] + 1) / 2).astype(int)
    # Y is a 1-D column slice, so its expected shape is (n,), not (n, 1).
    assert X.shape == (n, d) and Y.shape == (n,)
    return n, d, torch.Tensor(X), torch.Tensor(Y)


def obj(x, X, Y, n, lambd):
    """Regularised logistic loss.

    Returns ``sum_i [-y_i log s(x_i^T x) - (1 - y_i) log s(-x_i^T x)] / n
    + lambd * ||x||^2`` where ``s`` is the sigmoid, computed with log-sigmoid
    for numerical stability.
    """
    log_sig = torch.nn.functional.logsigmoid
    positive_part = -(Y * log_sig(X @ x))
    negative_part = -((1 - Y) * log_sig(-X @ x))
    data_term = torch.sum(positive_part + negative_part) / n
    penalty = lambd * torch.linalg.norm(x) ** 2
    return data_term + penalty

def gradient(x, X, Y, n, lambd):
    """Gradient of ``obj``: ``X^T (sigmoid(X x) - Y) / n + 2 * lambd * x``.

    The data term is a single matrix-vector product. Materialising
    ``diag(residual)`` would need an n x n matrix, i.e. O(n^2) memory.
    Expects 1-D ``x`` (d,) and ``Y`` (n,).
    """
    residual = torch.sigmoid(X @ x) - Y
    data_grad = X.t() @ residual
    return data_grad.squeeze() / n + 2 * lambd * x

def hessian(x, X, n, lambd, d):
    """Hessian of ``obj``: ``X^T diag(s(Xx) s(-Xx)) X / n + 2 * lambd * I_d``.

    The diagonal weighting is applied by scaling rows of X (broadcasting)
    instead of forming an n x n diagonal matrix, which would cost O(n^2) memory.
    """
    curvature = torch.sigmoid(X @ x) * torch.sigmoid(-X @ x)
    data_h = X.t() @ (curvature[:, None] * X)
    return data_h / n + 2 * lambd * torch.eye(d)

def sample_sketch(m, d, seed):
    """Draw an m x d Gaussian sketch whose entries are i.i.d. N(0, 1/m).

    Each row is an independent sample from N(0, I_d / m). Torch's global RNG
    is reseeded with ``seed`` first, so equal arguments yield equal sketches.
    """
    torch.manual_seed(seed)
    mean = torch.zeros(d)
    cov = (1/m)*torch.eye(d)
    gaussian = torch.distributions.multivariate_normal.MultivariateNormal(mean, cov)
    sketch = gaussian.sample(torch.Size([m]))
    assert sketch.shape == (m, d)
    return sketch

def compute_sm(H_S, val):
    """Return ``tr((H_S - val * I)^{-1}) / m`` for an m x m matrix ``H_S``.

    Callers in this module pass ``val <= 0``, i.e. they evaluate the normalised
    trace of ``(H_S + |val| * I)^{-1}``.
    """
    sketch_dim = H_S.shape[0]
    return torch.trace(torch.inverse(H_S - val * torch.eye(sketch_dim))) / sketch_dim


def find_l_hat(H_S, l, tol=1e-6, max_iter=100):
    """Find ``l_hat`` in ``[5l/12, l]`` with ``compute_sm(H_S, -l_hat) ~= 1/l``.

    For symmetric PSD ``H_S``, ``compute_sm(H_S, -v) = tr((H_S + v I)^{-1}) / m``
    is decreasing in ``v``. When the two interval endpoints bracket ``1/l``,
    the root is located by bisection. Otherwise ``5l/12`` is returned.

    Args:
        H_S: sketched matrix (m x m).
        l: target regulariser; the target value is ``1/l``.
        tol: absolute tolerance on ``|compute_sm(H_S, -l_hat) - 1/l|``.
        max_iter: cap on bisection steps. The absolute tolerance can be
            unreachable in float32 when ``1/l`` is large, so the cap guarantees
            termination.
    """
    target = 1 / l
    lo, hi = 5 * l / 12, l
    if not (compute_sm(H_S, -lo) >= target and compute_sm(H_S, -hi) <= target):
        return 5 * l / 12
    mid = (lo + hi) / 2
    for _ in range(max_iter):
        gap = compute_sm(H_S, -mid) - target
        # Keep bisecting while NOT yet within tolerance.
        if torch.abs(gap) <= tol:
            break
        if gap > 0:
            # Value still above 1/l: the root lies at a larger v.
            lo = mid
        else:
            hi = mid
        mid = (lo + hi) / 2
    return mid

def debias(H, l, d, seed, q):
    """
    Debiased sketched estimate used by newton() in place of (H + l*I)^{-1}.

    For each sketch j = 0..q-1 (seeded with seed + 101*j):
       step 1: sketch dimension m -- start at 2 and double until
               compute_sm(S H S^T, -5l/12) > 1/l or m >= d
       step 2: ridge l_hat -- 5l/12 when compute_sm(S H S^T, 0) <= 1/l,
               otherwise find_l_hat(S H S^T, l)
       step 3: accumulate S^T (S H S^T + l_hat*I)^{-1} S
    Returns the average of the q accumulated matrices.
    """
    accumulated = 0
    for sketch_idx in range(q):
        sketch_seed = seed + 101*sketch_idx
        # step 1
        m = 2
        while m < d:
            trial = sample_sketch(m, d, sketch_seed)
            if compute_sm(trial@H@trial.T, -5*l/12) > 1/l:
                break
            m *= 2
        # step 2: re-draw at the final m; sample_sketch reseeds, so this equals
        # `trial` whenever the loop above broke early.
        S = sample_sketch(m, d, sketch_seed)
        SHS = S@H@S.T
        l_hat = 5*l/12 if compute_sm(SHS, 0) <= 1/l else find_l_hat(SHS, l)
        # step 3
        accumulated += S.T@torch.inverse(SHS + l_hat*torch.eye(S.shape[0]))@S
    return accumulated/q

def nodebias(H, l, d, seed, q):
    """
    Sketched estimate of (H + l*I)^{-1} without debiasing (baseline for debias()).

    For each sketch j = 0..q-1 (seeded with seed + 101*j):
       step 1: sketch dimension m -- start at 2 and double until
               compute_sm(S H S^T, -5l/12) > 1/l or m >= d
       step 2: accumulate S^T (S H S^T + l*I)^{-1} S, i.e. the plain ridge l
               (debias() replaces l by a corrected l_hat here)
    Returns the average of the q accumulated matrices.
    """
    worker_sum = 0
    for j in range(q):
        worker_seed = seed + 101 * j
        m = 2
        while m < d:
            S = sample_sketch(m, d, worker_seed)
            if compute_sm(S @ H @ S.T, -5 * l / 12) > 1 / l:
                break
            m = 2 * m
        # Re-draw at the final m, exactly as debias() does. This defines S even
        # when the loop never ran (d <= 2), and keeps both estimators on the same
        # sketch when the loop exhausted without breaking.
        S = sample_sketch(m, d, worker_seed)
        SHS = S @ H @ S.T
        worker_sum += S.T @ torch.inverse(SHS + l * torch.eye(S.shape[0])) @ S
    return worker_sum / q
    
def line_search(x, del_x, dec,  X, Y, n, lambd) -> tuple[float,int]:
    """Backtracking line search along ``del_x``.

    Starting from t = 1, t is halved (at most 10 times) while
        obj(x + t*del_x) > obj(x) + (1/4) * t * dec.
    For this to be the Armijo sufficient-decrease test, ``dec`` must be the
    directional derivative g^T del_x, which is negative along a descent direction.

    Returns:
        (t, number of halvings performed).
    """
    shrink_factor = 1/2
    slope = 1/4*dec
    f_start = obj(x, X, Y, n, lambd)
    step = 1
    halvings = 0
    while halvings < 10 and (obj(x+step*del_x, X, Y, n, lambd) > f_start+step*slope).any():
        halvings += 1
        step = shrink_factor*step
    return step, halvings

def giant(x, X, lambd, d, q):
    """Average of the q per-worker inverses of the local regularised Hessians.

    The rows of X are split into q contiguous shards of ``X.shape[0] // q`` rows
    each (trailing rows are ignored). Worker k contributes
    ``(X_k^T diag(s(X_k x) s(-X_k x)) X_k / n_k + 2 * lambd * I)^{-1}``.

    Raises:
        ValueError: if q exceeds the number of rows (empty shards).
    """
    rows = X.shape[0] // q
    if rows == 0:
        raise ValueError(f'cannot split {X.shape[0]} rows across {q} workers')
    reg = 2 * lambd * torch.eye(d)
    estimate_hessian_inv = 0
    for k in range(q):
        block = X[k * rows:(k + 1) * rows, :]
        weights = torch.sigmoid(block @ x) * torch.sigmoid(-block @ x)
        # Row scaling instead of torch.diag(weights) avoids an n_k x n_k matrix.
        local_h = block.t() @ (weights[:, None] * block)
        estimate_hessian_inv += torch.inverse(local_h / rows + reg)
    return estimate_hessian_inv / q

def determinant(x, X, lambd, d, q):
    """Determinant-weighted average of per-worker local Hessian inverses.

    With local regularised Hessians H_k (built from q contiguous shards of
    ``X.shape[0] // q`` rows; trailing rows ignored), returns
    ``sum_k det(H_k) H_k^{-1} / sum_k det(H_k)``.

    The weights ``det(H_k) / sum_j det(H_j)`` are computed as a softmax of
    log-determinants. Raw determinants of d x d matrices easily underflow
    (e.g. 0.02**40 is 0 in float32), which would turn the ratio into 0/0 = NaN.
    NOTE(review): assumes lambd > 0 so every H_k is positive definite
    (det > 0); log|det| is then the log-determinant.

    Raises:
        ValueError: if q exceeds the number of rows (empty shards).
    """
    rows = X.shape[0] // q
    if rows == 0:
        raise ValueError(f'cannot split {X.shape[0]} rows across {q} workers')
    reg = 2 * lambd * torch.eye(d)
    inverses = []
    log_dets = []
    for k in range(q):
        block = X[k * rows:(k + 1) * rows, :]
        weights = torch.sigmoid(block @ x) * torch.sigmoid(-block @ x)
        local_h = block.t() @ (weights[:, None] * block) / rows + reg
        log_dets.append(torch.linalg.slogdet(local_h).logabsdet)
        inverses.append(torch.inverse(local_h))
    mix = torch.softmax(torch.stack(log_dets), dim=0)
    return (mix[:, None, None] * torch.stack(inverses)).sum(dim=0)

def shrink(x,X, lambd, d, q):
    """Average of per-worker inverses of rescaled local regularised Hessians.

    For each of q contiguous shards of n_k = ``X.shape[0] // q`` rows (trailing
    rows ignored), with H_k the unregularised local Hessian:
        d_eff = tr(H_k (H_k + 2 lambd I)^{-1})
        scale = 1 / (1 - d_eff / n_k)
    and the worker contributes ``(scale * H_k + 2 lambd I)^{-1}``.
    NOTE(review): scale is infinite/negative if d_eff >= n_k; not guarded here.

    Raises:
        ValueError: if q exceeds the number of rows (empty shards).
    """
    rows = X.shape[0] // q
    if rows == 0:
        raise ValueError(f'cannot split {X.shape[0]} rows across {q} workers')
    reg = 2 * lambd * torch.eye(d)
    estimate_hessian_inv = 0
    for k in range(q):
        block = X[k * rows:(k + 1) * rows, :]
        weights = torch.sigmoid(block @ x) * torch.sigmoid(-block @ x)
        # Row scaling instead of torch.diag(weights) avoids an n_k x n_k matrix.
        local_h = block.t() @ (weights[:, None] * block) / rows
        d_eff = (local_h @ torch.inverse(local_h + reg)).diagonal().sum()
        scale = 1 / (1 - d_eff / rows)
        estimate_hessian_inv += torch.inverse(scale * local_h + reg)
    return estimate_hessian_inv / q

def dane(x,X,Y,g, lambd, d, q):
    """One DANE-style round: each worker solves a local problem; results are averaged.

    Rows are split into q contiguous shards of n_k = ``X.shape[0] // q`` rows
    (trailing rows ignored). Worker k minimises, with cvxpy's CLARABEL solver,

        f_k(w) - (grad f_k(x) - eta * g)^T w + (mu / 2) * ||w - x||^2,

    where f_k(w) is the mean logistic loss on the shard plus lambd * ||w||^2,
    and ``g`` is the gradient supplied by the caller. eta = 1, mu = 0.5.

    Returns:
        The average of the q local minimisers as a float tensor of shape (d,).

    Raises:
        ValueError: if q exceeds the number of rows (empty shards).
        RuntimeError: if a local solve returns no solution.
    """
    eta = 1
    mu = 0.5
    rows = X.shape[0] // q
    if rows == 0:
        raise ValueError(f'cannot split {X.shape[0]} rows across {q} workers')
    x_np = x.numpy()
    g_np = g.numpy()
    w_val = 0
    for k in range(q):
        X_k = X[k * rows:(k + 1) * rows, :]
        Y_k = Y[k * rows:(k + 1) * rows]
        # Local gradient at x as X_k^T r, rather than through an n_k x n_k diagonal
        # matrix; torch.sigmoid also avoids overflow warnings from np.exp.
        residual = torch.sigmoid(X_k @ x) - Y_k
        local_grad = (X_k.t() @ residual).numpy() / rows + 2 * lambd * x_np
        X_np = X_k.numpy()
        Y_np = Y_k.numpy()
        w = cp.Variable(d)
        term1 = cp.multiply(Y_np, cp.logistic(-X_np @ w))
        term2 = cp.multiply(1 - Y_np, cp.logistic(X_np @ w))
        local_loss = cp.sum(term1 + term2) / rows + lambd * cp.sum_squares(w)
        linear = local_grad - eta * g_np
        proximal = (mu * 0.5) * cp.sum_squares(w - x_np)
        prob = cp.Problem(cp.Minimize(local_loss - linear @ w + proximal))
        prob.solve(solver='CLARABEL')
        if w.value is None:
            raise RuntimeError(
                f'DANE local problem on worker {k} returned no solution (status: {prob.status})'
            )
        w_val += torch.Tensor(w.value)
    return w_val / q

def disco(x,X, lambd, d, q):
    """Inverse of the first worker's local regularised Hessian.

    Only the first ``X.shape[0] // q`` rows are used:
    ``(X_1^T diag(s(X_1 x) s(-X_1 x)) X_1 / n_1 + 2 * lambd * I)^{-1}``.

    Raises:
        ValueError: if q exceeds the number of rows (empty shard).
    """
    rows = X.shape[0] // q
    if rows == 0:
        raise ValueError(f'cannot split {X.shape[0]} rows across {q} workers')
    block = X[:rows, :]
    weights = torch.sigmoid(block @ x) * torch.sigmoid(-block @ x)
    # Row scaling instead of torch.diag(weights) avoids an n_1 x n_1 matrix.
    local_h = block.t() @ (weights[:, None] * block)
    return torch.inverse(local_h / rows + 2 * lambd * torch.eye(d))


def newton_dane_e(x_init, X, Y, n, lambd, d, q, newton_iter=50):
    """Run up to ``newton_iter`` DANE rounds starting from ``x_init``.

    Each round computes the full gradient at x and replaces x by ``dane(...)``.
    There is no convergence test. The run either uses all rounds
    ('max_iter reached') or stops early when a round raises ('solver failed ...').

    Returns:
        (record, x, i, stop_reason): ``record`` holds (round, objective) pairs,
        with entry 0 the initial objective. ``i`` is the index of the last round
        attempted.
    """
    x = x_init
    stop_reason = 'max_iter reached'
    record = [(0, obj(x, X, Y, n, lambd))]
    i = 0
    for i in range(newton_iter):
        g = gradient(x, X, Y, n, lambd)
        try:
            x = dane(x, X, Y, g, lambd, d, q)
        except Exception as err:
            # Stop instead of re-running the same failing solve and
            # appending duplicate records.
            stop_reason = f'solver failed ({type(err).__name__}: {err})'
            break
        record.append((i + 1, obj(x, X, Y, n, lambd)))
    print(stop_reason)
    return (record, x, i, stop_reason)



def _data_hessian(x, X, n):
    """Return the unregularised logistic Hessian ``X^T diag(s(Xx) s(-Xx)) X / n``."""
    curvature = torch.sigmoid(X @ x) * torch.sigmoid(-X @ x)
    # Row scaling instead of torch.diag(curvature) avoids an n x n matrix.
    return X.t() @ (curvature[:, None] * X) / n


def newton(x_init, X, Y, n, lambd, d, q, seed, method, newton_iter=50, newton_eps=1e-4):
    """Approximate Newton's method for l2-regularised logistic regression.

    Each iteration builds an approximate inverse Hessian ``h_inv`` with the
    strategy selected by ``method``. It stops when half the squared Newton
    decrement ``g^T h_inv g / 2`` is at most ``newton_eps``. Otherwise it steps
    along ``-h_inv g`` with a backtracking (Armijo) line search.

    method:
        'h'                                     exact inverse of hessian(...)
        'giant', 'determinant', 'shrinkage', 'disco'
                                                per-worker estimators
        'debias', 'nodebias'                    sketched estimators applied to the
                                                unregularised data Hessian with
                                                ridge 2*lambd (seed + 97*i per step)

    Returns:
        (record, x, i, stop_reason): ``record[k] = (k, objective after k steps)``.
        ``i`` is the last iteration index and ``stop_reason`` is 'converged' or
        'max_iter reached'.

    Raises:
        ValueError: for an unknown ``method``.
    """
    estimators = {'giant': giant, 'determinant': determinant,
                  'shrinkage': shrink, 'disco': disco}
    sketchers = {'debias': debias, 'nodebias': nodebias}
    if method != 'h' and method not in estimators and method not in sketchers:
        raise ValueError(f'unknown method: {method!r}')
    x = x_init
    stop_reason = 'max_iter reached'
    record = [(0, obj(x, X, Y, n, lambd))]
    i = 0
    for i in range(newton_iter):
        g = gradient(x, X, Y, n, lambd)
        if method == 'h':
            h_inv = torch.inverse(hessian(x, X, n, lambd, d))
        elif method in estimators:
            h_inv = estimators[method](x, X, lambd, d, q)
        else:
            H = _data_hessian(x, X, n)
            h_inv = sketchers[method](H, 2 * lambd, d, seed + 97 * i, q)
        del_x = h_inv @ g
        dec = g @ del_x  # squared Newton decrement g^T h_inv g
        if torch.abs(dec / 2) <= newton_eps:
            stop_reason = 'converged'
            break
        # The step is -del_x, whose directional derivative is g^T(-del_x) = -dec.
        # Armijo needs that (negative) slope, otherwise ascent steps are accepted.
        t, _ = line_search(x, -del_x, -dec, X, Y, n, lambd)
        x = x - t * del_x
        # Record AFTER the update so record[k] is the objective after k steps.
        record.append((i + 1, obj(x, X, Y, n, lambd)))
    return (record, x, i, stop_reason)

def get_x_axis(data_set):
    """Build a common x-axis (Newton steps) for a collection of runs.

    Each run is a sequence of (step, value) records. The run with the largest
    final step (the first one on ties) sets the end point and the length. The
    axis is ``np.linspace(0, end, 2)`` when that run has exactly two records,
    and ``np.linspace(0, end, max(length, 100))`` otherwise. With no runs the
    end point is -1 and the length 0.
    """
    longest_step, longest_len = -1, 0
    for run in data_set:
        final_step = run[-1][0]
        if final_step > longest_step:
            longest_step, longest_len = final_step, len(run)
    num_points = 2 if longest_len == 2 else max(longest_len, 100)
    return np.linspace(0, longest_step, num_points)

def interpolate(data, axis, optimals):
    """Summarise many runs on a common axis.

    Each run is converted with ``get_numbers(run, axis)``. The per-point median
    and the 20% / 80% quantiles across runs are returned as
    ``(median, lower, upper)``. ``optimals`` is accepted but not used; every run
    is measured against its own final objective value.
    """
    curves = np.zeros((len(data), len(axis)))
    for row, run in enumerate(data):
        curves[row] = get_numbers(run, axis)
    assert curves.shape == (len(data), len(axis))
    return tuple(np.quantile(curves, level, axis=0) for level in (0.5, 0.2, 0.8))

def get_numbers(data, axis, optimal=None):
    """Relative optimality gap of one run, linearly interpolated onto ``axis``.

    ``data`` is a sequence of (step, objective) records. The gap is
    ``|f - f_ref| / |f_ref|``. ``f_ref`` is ``optimal`` when given, otherwise
    the run's final value. In the latter case, when the final value is not
    below 0.5 (or is NaN), the first gap is instead measured against the
    hard-coded constant 0.4833 (NOTE(review): presumably a dataset-specific
    reference optimum — confirm its origin).
    """
    table = np.array(torch.Tensor(data))
    steps, values = table[:, 0], table[:, 1]
    if optimal is not None:
        reference = np.array(optimal.cpu())
        gaps = np.abs(values - reference) / np.abs(reference)
    else:
        final = values[-1]
        gaps = np.abs(values - final) / np.abs(final)
        if not final < 0.5:
            gaps[0] = np.abs(values[0] - 0.4833) / np.abs(0.4833)
    return np.interp(axis, steps, gaps)

def plot_multi_realdata(data):
    """Compare debiased and non-debiased sketched Newton on dataset ``data``.

    For seeds 0..9 the dataset is reshuffled with ``real_data`` and ``newton``
    is run with method 'debias' and then 'nodebias'. The median and the
    20%-80% quantile band of the relative optimality gap are plotted on a log
    scale and saved to ``uci_logistic-<data>.pdf``.
    """
    lambd = 0.01
    eps = 1e-6
    max_iter = 50
    q = 20
    runs = {'debias': [], 'nodebias': []}
    for i in range(10):
        print(i)
        n, d, X, Y = real_data(data, i)
        x_init = torch.zeros(d)
        for method in runs:  # 'debias' first, then 'nodebias'
            result = newton(x_init, X, Y, n, lambd, d, q, newton_iter=max_iter,
                            newton_eps=eps, seed=i, method=method)
            runs[method].append(result[0])
    x_axis = {method: get_x_axis(records) for method, records in runs.items()}
    plot_data = {method: interpolate(runs[method], x_axis[method], None) for method in runs}

    # A single figure: a separate plt.figure() before plt.subplots() would only
    # create an extra, unused (100 x 100 inch) figure.
    fig, ax = plt.subplots()
    styles = {
        'nodebias': ('No Debiasing', (0.9677975592919913, 0.44127456009157356, 0.5358103155058701)),
        'debias': ('Debiasing (Ours)', (0.3126890019504329, 0.6928754610296064, 0.1923704830330379)),
    }
    for method in ('nodebias', 'debias'):
        label, colour = styles[method]
        median, lower, upper = plot_data[method]
        ax.plot(x_axis[method], median, label=label, c=colour)
        ax.fill_between(x_axis[method], lower, upper, alpha=0.3, facecolor=colour)
    ax.legend(fontsize=18, loc="upper right")
    ax.set_yscale('log')
    ax.set_xlabel('Newton Steps', fontsize=20)
    ax.set_title(data + ' (logistic regression)', fontsize=20)
    ax.set_ylabel('Log Optimality Gap', fontsize=20)
    ax.tick_params(axis='x', labelsize=17)
    ax.tick_params(axis='y', labelsize=17)
    fig.tight_layout()
    fig.savefig('uci_logistic-%s.pdf' % (data))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)




if __name__ == '__main__':
    # Expect exactly one argument: the dataset name passed to read_dataset.
    if len(sys.argv) != 2:
        sys.exit(f'usage: {sys.argv[0]} DATASET_NAME')
    plot_multi_realdata(sys.argv[1])

